# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# File Name          : _simulation_functions.R
# Programmer Name    : Luis Campos
#                     luiscampos@g.harvard.edu
#
# Purpose            : This file contains simulation functions
#
# Input              : None
# Output             : None
# Usage              : 
# 
# References         : None
#
#
# Platform           : R
# Version            : v3.3.0
# Date               : 
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Replication Information:
#   We do expect differences when running this code on different systems 
#   (and with different seeds), but the differences all tend to be quite 
#   small and the general trends remain. If you would like to replicate 
#   the exact numbers and figures used in the article, the readme lists 
#   the session info. Please install and load the same package versions 
#   (and the same R version) listed there to ensure exact replication. 
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

##
## Take a population and some parameters and run a simulation on it
##


library( xtable )


# gamma is for type 5, where the dependence of Y on w is gradually removed. The distributions of Y(0), Y(1), w, all remain the same. 

##
## Generate a synthetic population of potential outcomes and weights.
##
## N:         population size (number of rows generated)
## type:      data-generating scenario
##              2 - constant treatment effect (Y1 = Y0 + 30)
##              3 - Y0 and the treatment effect both depend on w
##              4 - default scenario with Y0 and Y1 each rescaled
##                  (multiplied by 4/range, then shifted down by 2)
##              5 - like type 3, but outcomes depend on a latent u1 and
##                  w is tied to it through gamma (gamma = 1 gives w = u1,
##                  gamma = 0 gives a w generated independently of u)
##              any other value - default scenario (Y1 depends on w)
## add.noise: if TRUE, add 20*rnorm noise to Y0 and Y1 and runif(0, 0.5)
##            noise to w
## center:    currently unused; kept so existing calls keep working
## make.plot: if TRUE, plot a random subset of up to 1000 rows
## gamma:     required when type = 5 (see above); ignored otherwise
##
## Returns a data frame with columns Y1, Y0, w, t (= Y1 - Y0) and
## pi (= 1 / raw w); w is overwritten with mean(pi) / pi. Type 5
## additionally keeps the columns u and u1.
## Side effects: prints tau, a summary of the data and the max/min weight
## ratio; consumes random numbers (set a seed beforehand to reproduce).
##
make.population = function( N= 10000, type=3, add.noise=TRUE, center = NULL, make.plot = TRUE, gamma = NULL) {
	
	# Without this check a missing gamma only surfaces later as an
	# unrelated "replacement has 0 rows" error when w is assigned.
	if ( type==5 && is.null(gamma) ) {
		stop( "make.population: 'gamma' must be supplied when type = 5", call. = FALSE )
	}
	
	# Generate some data (some huge population)
	# (Y1 = 1:N is only a placeholder; it is overwritten below)
	dat = data.frame( Y1 = 1:N, Y0 = runif(N,1,100), w=runif(N,1,20) )
	dat$Y0 = dat$Y0 + dat$w
	dat$Y1 = dat$Y0 + 10*sqrt( (max(dat$w) - dat$w)) + rnorm(N)
	
	# sim #2 (constant tx effect)
	if ( type==2 ) {
		dat$Y1 = dat$Y0 + 30  # const tx effect 
	}
	
	# sim #3 (trying to make double-hajek fail)
	if ( type==3 ) {
		dat = data.frame( Y1 = 1:N, Y0 = runif(N,1,100), w=runif(N,1,25) )
		dat$Y0 = 20 * (6 - sqrt(dat$w)) + 5 * rnorm(N)
		dat$Y1 = dat$Y0 + 10*sqrt( (max(dat$w) - dat$w))
	}
	
	# sim #4 (default scenario, outcomes rescaled)
	if ( type==4 ) {
		dat$Y0 = dat$Y0 * 4 /(max(dat$Y0) - min(dat$Y0)) - 2
		dat$Y1 = dat$Y1 * 4 /(max(dat$Y1) - min(dat$Y1)) - 2
	}
	
	# sim #5 (outcomes depend on latent u1; w is linked to u through gamma)
	if ( type==5 ) {
		dat = data.frame( Y1 = 1:N, Y0 = runif(N,1,100), u = rnorm(N))
		dat$u1 = 1 + 24*pnorm(dat$u)
		dat$Y0 = 20 * (6 - sqrt(dat$u1)) + 5 * rnorm(N)
		dat$Y1 = dat$Y0 + 10*sqrt( (max(dat$u1) - dat$u1))

		dat$w = 1 + 24*pnorm(gamma*dat$u + (1-gamma)*rnorm(N))
	}

	
	# add more noise
	if ( add.noise ) {
		dat$Y0 = dat$Y0 + 20*rnorm(N)
		dat$Y1 = dat$Y1 + 20*rnorm(N)
		dat$w = dat$w + runif(N, 0, 0.5)
	}


	# calculate actual treatment effect and normalize all the weights
	dat$t = dat$Y1 - dat$Y0
	dat$pi = 1 / dat$w 
	dat$w = mean(dat$pi) / dat$pi


	# look at it
	# (cap the subset size: sample() errors when asked for more rows
	# than the population has, which used to break N < 1000)
	if(make.plot) plot( dat[ sample(nrow(dat), min(1000, nrow(dat)) ), ] )
	
	tau = mean(dat$Y1) - mean(dat$Y0)
	cat( "tau = ", tau, "\n\nSummary dat:\n" )
	
	print( summary(dat) )
	cat( "\nMax weight difference: ", max(dat$w) / min(dat$w), "\n" )

	dat
}





## 
## Run the simulation: repeatedly draw a sample from the population,
## assign treatment, and compute the estimates for that draw.
## It includes calculating bootstrap SEs with each draw.
## Warning: Fairly time-intensive to run.
## dat: The population to sample from (data frame); the columns pi, w, Y1
##      and Y0 are read here
## n: size of each sample (not used when bern.sample is TRUE)
## K: passed on to calc.txs() and calc.boot.se()
## B: number simulations
## R: number bootstrap replicates; with R <= 0 the bootstrap is skipped and
##    NA placeholders are returned for the SEs
## bern.sample: if TRUE, the sample is drawn by calling bern.sample( dat$pi )
##      instead of sample() with replacement and prob = dat$pi
## Returns the sapply() result over the B draws: for each draw tau.S, nu.S,
## nu.S.h, the calc.txs() estimates and the "SE."-prefixed bootstrap SEs.
run.simulation = function( dat, n = 500, K = 5, B = 50, R = 30, bern.sample=FALSE ) {

	pb = txtProgressBar( min = 0, max = B, style = 3 )
	# Close the progress bar on every exit path, including when one of the
	# replicates stops with an error.
	on.exit( close(pb), add = TRUE )
	
	reps = sapply( seq_len(B), function( b ) {
	
		setTxtProgressBar( pb, b )
		
		# get sample (and potential outcomes)
		if ( bern.sample ) {
			# The logical argument shares its name with the sampler function
			# (defined elsewhere). The call still reaches the function because
			# R skips non-function bindings when resolving a name being called.
			samp.ind = bern.sample( dat$pi )
		} else {
			samp.ind = sample( nrow(dat), n, replace=TRUE, prob=dat$pi )
		}
		
		edat = dat[ samp.ind, ] 
		
		# if we had all potential outcomes
		tau.S = mean(edat$Y1 - edat$Y0 )
		nu.S = mean( edat$w * (edat$Y1 - edat$Y0 ) )
		nu.S.h = nu.S * nrow(edat) / sum(edat$w)
		
		# apply treatment and get outcomes
		edat = randomizeDat( edat )
		
		# get various estimates
		estimates = calc.txs( edat$Yobs, edat$Tx, edat$w, K=K )
		
		# get bootstrap SEs
		if ( R > 0 ) {
			SE.frame = calc.boot.se( edat$Yobs, edat$Tx, edat$w, K=K, R=R )
			SEs = SE.frame$SE
			names(SEs) = paste("SE.", SE.frame$Var1, sep="" )	
		} else {
			# NOTE(review): the 7 is hard-coded; presumably it should match
			# the number of SEs calc.boot.se() returns -- confirm.
			SEs = rep( NA, 7 )
			names(SEs) = paste("SE.", names(estimates[1:7]), sep="" )
		}
		
		c( tau.S = tau.S, nu.S = nu.S, nu.S.h = nu.S.h, estimates, SEs )
	} )

	reps
}


## 
## Investigate performance of estimators
##
## pop:  population data frame; the true effect tau is mean( Y1 - Y0 )
## reps: matrix of simulation results with one named row per quantity and
##       one column per draw (as built by run.simulation). Rows whose name
##       contains "SE" are treated as bootstrap SEs; rows "n" and "Z" are
##       dropped from the estimator summary.
## SE:   if TRUE, also summarise the bootstrap SEs and add them to the
##       final table as column boot.SE
##
## Prints the summaries and a LaTeX version of the final table, and
## invisibly returns the final table (data frame).
## Uses melt() and ddply(), which are not loaded in this file
## (presumably from reshape2 and plyr -- confirm they are attached).
##
analyze.simulation = function( pop, reps, SE = TRUE) {

	tau = mean( pop$Y1 - pop$Y0 )
		
	# Split the bootstrap-SE rows from the estimator rows with a logical
	# mask. Negative indexing with grep() selects *nothing* when no row
	# matches, and drop=FALSE keeps the matrix shape when only a single
	# row or a single draw is left.
	is.SE = rep( FALSE, nrow(reps) )
	is.SE[ grep( "SE", rownames(reps) ) ] = TRUE
	if ( SE && !any( is.SE ) ) {
		stop( "analyze.simulation: SE = TRUE but 'reps' has no rows with \"SE\" in their name", call. = FALSE )
	}
	SEs = reps[ is.SE, , drop=FALSE ]
	reps = reps[ !is.SE, , drop=FALSE ]
	
	rs = melt( reps )
	rs = subset( rs, Var1 != "Z" & Var1 != "n" )
	
	# weird hack: store tau as a column so that it can be referenced
	# inside the summarize() call below
	rs$tau = tau
	res = ddply( rs, .(Var1), summarize,
			n.na = sum( is.na( value ) ),
			mean = mean( value, na.rm=TRUE ),
			bias = mean( value - tau, na.rm=TRUE ),
			sd = sd( value, na.rm=TRUE ),
			RMSE = sqrt( mean( (value - tau)^2, na.rm=TRUE ) ) )
	
	res$mean = round( res$mean, digits=2 )
	res$bias = round( res$bias, digits=2 )
	res$sd = round( res$sd, digits=2 )
	res$RMSE = round( res$RMSE, digits=2 )
	
	cat( "Results of estimators:\n" )
	print( res )
	

	# Look at the bootstrap SEs
	if(SE){
		rs = melt( SEs )
		
		res.SE = ddply( rs, .(Var1), summarize,
				n.na = sum( is.na( value ) ),
				mean = mean( value, na.rm=TRUE ),
				sd = sd( value, na.rm=TRUE ) )
		
		cat( "\nBootstrap results:\n" )
		print( res.SE )
	}	

	# Make table for paper
	res$n.na = NULL

	if(SE){
		boot.means = round( res.SE$mean, digits=2 )
		
		# Line each SE up with its estimator through the "SE.<name>" row
		# naming instead of assuming a fixed table layout.
		rows = match( sub( "^SE\\.", "", as.character( res.SE$Var1 ) ),
				as.character( res$Var1 ) )
		if ( anyNA( rows ) ) {
			# Names do not line up: fall back to the historical layout,
			# where the SE rows follow the first three (oracle) rows.
			rows = 3 + seq_along( boot.means )
		}
		
		res$boot.SE = NA
		res$boot.SE[ rows ] = boot.means
	}
	
	cat( "\nFinal table for paper:\n" )
	print( res )
	
	
	
	cat( "\nLatex version of table for paper:\n" )
	print.xtable( xtable( res ), hline.after=c(3,4,7) )
	
	invisible( res )
}











# Panel function for pairs() (used here as diag.panel): draws a histogram
# of x, scaled so the tallest bar has height 1, inside the current panel.
#   x   : numeric vector to summarise
#   ... : further graphical arguments passed on to rect()
panel.hist <- function(x, ...)
{
    usr <- par("usr")
    # Restore the coordinates on exit. par() only sets a parameter that is
    # passed by name; the unnamed par(usr) restores nothing and warns.
    on.exit(par(usr = usr), add = TRUE)
    # Keep the x-range of the panel, but give the bars a 0..1.5 y-range.
    par(usr = c(usr[1:2], 0, 1.5))
    h <- hist(x, plot = FALSE, breaks = 20)
    breaks <- h$breaks
    nB <- length(breaks)
    y <- h$counts / max(h$counts)
    rect(breaks[-nB], 0, breaks[-1], y, col = "lightgrey", ...)
}
# Panel function for pairs() (used here as lower.panel): prints the
# absolute correlation of x and y in the middle of the panel.
#   x, y   : numeric vectors
#   digits : significant digits used when formatting the correlation
#   prefix : text pasted in front of the formatted value
#   ...    : accepted so pairs() can pass extra arguments; not used
panel.cor <- function(x, y, digits = 2, prefix = "", ...)
{
    usr <- par("usr")
    # Restore the coordinates on exit. par() only sets a parameter that is
    # passed by name; the unnamed par(usr) restores nothing and warns.
    on.exit(par(usr = usr), add = TRUE)
    # Unit square, so (0.5, 0.5) below is the centre of the panel.
    par(usr = c(0, 1, 0, 1))

    r <- abs(cor(x, y))
    # format() formats both values in a common way; the dummy second value
    # is there to influence the number of decimals and is then dropped.
    txt <- format(c(r, 0.123456789), digits = digits)[1]
    txt <- paste0(prefix, txt)
    text(0.5, 0.5, txt, cex = 3)
}




# Draw two scatterplot matrices: one for a random subset (up to 1000 rows)
# of the population, and one for a single weighted sample drawn from it
# after treatment assignment and stratification on the weights.
#   popA : population data frame; the columns pi, Y1, Y0, w and t are read
#   n    : size of the sample drawn with replacement (prob = popA$pi)
#   K    : second argument handed to stratify(). When NULL (the default),
#          the variable K visible from the environment this function was
#          defined in is used, which is what the function relied on before
#          K became an argument.
# Called for its plots; consumes random numbers.
pop_sample.plots = function(popA, n = 500, K = NULL){

    if (is.null(K)) {
        # Same lexical lookup as the former free variable, made explicit.
        defining.env = parent.env(environment())
        if (!exists("K", envir = defining.env)) {
            stop("pop_sample.plots: supply 'K' or define a variable K first", call. = FALSE)
        }
        K = get("K", envir = defining.env)
    }

    dat = popA
    samp.ind = sample( nrow(dat), n, replace=TRUE, prob=dat$pi )
    edat = dat[ samp.ind, ] 

    edat = randomizeDat( edat )

    edat$b = stratify( edat$w, K )

    # Cap the subset size: sample() errors when asked for more rows than
    # the population has.
    pop.rows = sample( nrow(popA), min(1000, nrow(popA)) )

    pairs( popA[ pop.rows, c('Y1', 'Y0', 'w', 't')], pch = 15, cex = 0.4, labels = c(expression(Y[i](1)), expression(Y[i](0)), expression(w[i]), expression(Delta[i])), diag.panel=panel.hist, cex.axis = 1.5, lower.panel = panel.cor)
    plot( edat[,c('Y1', 'Y0', 'w', 'b')], pch = 15, cex = 0.4, labels = c(expression(Y[i](1)), expression(Y[i](0)), expression(w[i]), expression(b[i])), diag.panel=panel.hist, cex.axis = 1.5, lower.panel = panel.cor)

}